{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/dkisselev-zz/llm_engineering/blob/wk7/Week_7_Excersise_fine_tuned_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GHsssBgWM_l0"
      },
      "source": [
        "# Predict Product Prices\n",
        "\n",
        "Model evaluation and inference tuning\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HnwMdAP3IHad"
      },
      "source": [
        "## Libraries and configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MDyR63OTNUJ6"
      },
      "outputs": [],
      "source": [
        "!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
        "!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yikV8pRBer9"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import re\n",
        "import math\n",
        "import numpy as np\n",
        "from google.colab import userdata\n",
        "from huggingface_hub import login\n",
        "import wandb\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
        "from datasets import load_dataset\n",
        "from peft import PeftModel\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uuTX-xonNeOK"
      },
      "outputs": [],
      "source": [
        "# Models\n",
        "\n",
        "# WB or HF location of artifacts\n",
        "ARTIFCAT_LOCATTION=\"HF\"\n",
        "\n",
        "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
        "\n",
        "PROJECT_NAME = \"pricer\"\n",
        "\n",
        "# RUN_NAME = \"2025-10-23_23.41.24\" # - Fine tuned 16 batches / 8 bit run\n",
        "# RUN_NAME = \"2025-10-25_05.02.00\" # - Fine tuned 4 batches / 4 bit / LoRA 64/128 / Gradient 8\n",
        "RUN_NAME = \"2024-09-13_13.04.39\" # Ed's model run\n",
        "\n",
        "# Hugging Face\n",
        "HF_USER = \"dkisselev\"\n",
        "\n",
        "if ARTIFCAT_LOCATTION==\"HF\":\n",
        "  PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
        "  # REVISION = None\n",
        "  REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
        "\n",
        "\n",
        "  # FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
        "\n",
        "  # Ed's model\n",
        "  FINETUNED_MODEL = f\"ed-donner/{PROJECT_RUN_NAME}\"\n",
        "else:\n",
        "  # Weights and Biases\n",
        "  WANDB_ENTITY = \"dkisselev\"\n",
        "  os.environ[\"WANDB_API_KEY\"]=userdata.get('WANDB_API_KEY')\n",
        "\n",
        "  MODEL_ARTIFACT_NAME = f\"model-{RUN_NAME}\"\n",
        "  REVISION_TAG=\"v22\"\n",
        "  WANDB_ARTIFACT_PATH = f\"{WANDB_ENTITY}/{PROJECT_NAME}/{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\"\n",
        "\n",
        "# Data set\n",
        "\n",
        "# DATASET_NAME = f\"{HF_USER}/pricer-data2\"\n",
        "DATASET_NAME = \"ed-donner/pricer-data\"\n",
        "\n",
        "# Hyperparameters for QLoRA\n",
        "QUANT_4_BIT = True\n",
        "K_SEARCH_LIMIT = 900\n",
        "\n",
        "# Used for writing to output in color\n",
        "GREEN = \"\\033[92m\"\n",
        "YELLOW = \"\\033[93m\"\n",
        "RED = \"\\033[91m\"\n",
        "BLUE = \"\\033[94m\"\n",
        "RESET = \"\\033[0m\"\n",
        "COLOR_MAP = {\"red\":RED, \"orange\": BLUE, \"green\": GREEN}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8JArT3QAQAjx"
      },
      "source": [
        "### Load Data\n",
        "\n",
        "Data is loaded from Huggin Face\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WyFPZeMcM88v"
      },
      "outputs": [],
      "source": [
        "# Log in to HuggingFace\n",
        "hf_token = userdata.get('HF_TOKEN')\n",
        "login(hf_token)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cvXVoJH8LS6u"
      },
      "outputs": [],
      "source": [
        "dataset = load_dataset(DATASET_NAME)\n",
        "train = dataset['train']\n",
        "test = dataset['test']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJWQ0a3wZ0Bw"
      },
      "source": [
        "## Load Tokenizer and Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lAUAAcEC6ido"
      },
      "outputs": [],
      "source": [
        "# 4 or 8 but quantization\n",
        "if QUANT_4_BIT:\n",
        "  quant_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        "    bnb_4bit_quant_type=\"nf4\"\n",
        "  )\n",
        "else:\n",
        "  quant_config = BitsAndBytesConfig(\n",
        "    load_in_8bit=True\n",
        "  )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OQy4pCk-dutf"
      },
      "outputs": [],
      "source": [
        "# Load model from w&b\n",
        "if ARTIFCAT_LOCATTION==\"WB\":\n",
        "  artifact = wandb.Api().artifact(WANDB_ARTIFACT_PATH, type='model')\n",
        "  artifact_dir = artifact.download() # Downloads to a local cache dir"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R_O04fKxMMT-"
      },
      "outputs": [],
      "source": [
        "# Load the Tokenizer and the Model\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "tokenizer.padding_side = \"right\"\n",
        "\n",
        "base_model = AutoModelForCausalLM.from_pretrained(\n",
        "    BASE_MODEL,\n",
        "    quantization_config=quant_config,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
        "\n",
        "if ARTIFCAT_LOCATTION==\"HF\":\n",
        "  # Load the fine-tuned model with PEFT\n",
        "  if REVISION:\n",
        "    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
        "  else:\n",
        "    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
        "else:\n",
        "  # Model at W&B\n",
        "  fine_tuned_model = PeftModel.from_pretrained(base_model, artifact_dir)\n",
        "\n",
        "print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UObo1-RqaNnT"
      },
      "source": [
        "## Hyperparameter helpers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n4u27kbwlekE"
      },
      "outputs": [],
      "source": [
        "def calculate_weighted_price(prices, probabilities):\n",
        "    \"\"\"\n",
        "    Calculates a normalized weighted average price.\n",
        "\n",
        "    Args:\n",
        "        prices (list or np.array): A list of prices.\n",
        "        probabilities (list or np.array): A list of corresponding probabilities (or weights).\n",
        "    Returns:\n",
        "        float: The normalized weighted average price.\n",
        "    \"\"\"\n",
        "    # Convert lists to numpy arrays\n",
        "    prices_array = np.array(prices)\n",
        "    probs_array = np.array(probabilities)\n",
        "\n",
        "    # Total of the probabilities to use for normalization\n",
        "    total_prob = np.sum(probs_array)\n",
        "\n",
        "    # Catch zero\n",
        "    if total_prob == 0:\n",
        "        if len(prices_array) > 0:\n",
        "            return np.mean(prices_array)\n",
        "        else:\n",
        "            return 0.0\n",
        "\n",
        "    # Weighted avrage\n",
        "    weighted_price = np.average(prices_array, weights=probs_array)\n",
        "\n",
        "    return weighted_price"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ROjIbGuH0FWS"
      },
      "outputs": [],
      "source": [
        "def get_top_k_predictions(prompt, device=\"cuda\"):\n",
        "    \"\"\"\n",
        "    Gets the top K price/probability pairs from the model.\n",
        "\n",
        "    Returns:\n",
        "        (list, list): A tuple containing (prices, probabilities)\n",
        "    \"\"\"\n",
        "    set_seed(42)\n",
        "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
        "    attention_mask = torch.ones(inputs.shape, device=device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
        "        next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
        "\n",
        "    next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
        "    top_prob, top_token_id = next_token_probs.topk(K_SEARCH_LIMIT)\n",
        "\n",
        "    prices = []\n",
        "    probabilities = []\n",
        "\n",
        "    for i in range(K_SEARCH_LIMIT):\n",
        "      predicted_token = tokenizer.decode(top_token_id[0][i])\n",
        "      probability_tensor = top_prob[0][i]\n",
        "\n",
        "      try:\n",
        "        price = float(predicted_token)\n",
        "      except ValueError as e:\n",
        "        price = 0.0\n",
        "\n",
        "      if price > 0:\n",
        "        prices.append(price)\n",
        "        probabilities.append(probability_tensor.item())\n",
        "\n",
        "    if not prices:\n",
        "      return [], []\n",
        "\n",
        "    return prices, probabilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tnmTAiEG32xK"
      },
      "outputs": [],
      "source": [
        "def make_prompt(text):\n",
        "  if ARTIFCAT_LOCATTION==\"HF\":\n",
        "      return text\n",
        "  p_array = text.split(\"\\n\")\n",
        "  p_question = p_array[0].replace(\"How much does this cost to the nearest dollar?\",\"What is the price of this item?\")\n",
        "  p_title = p_array[2]\n",
        "  p_descr = re.sub(r'\\d', '', p_array[3])\n",
        "  p_price = p_array[5]\n",
        "  prompt =  p_title + \"\\n\" + p_descr + \"\\n\" + \"Question: \"+ p_question + \"\\n\\n\" + p_price\n",
        "  # prompt = p_array[0] + \"\\n\\n\\n\" + p_title + \"\\n\\n\" + p_descr + \"\\n\\n\" + p_price\n",
        "  # return text\n",
        "  return prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VNAEw5Eg4ABk"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "\n",
        "class Tester:\n",
        "\n",
        "    def __init__(self, predictor, data, title=None, size=250):\n",
        "        self.predictor = predictor\n",
        "        self.data = data\n",
        "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
        "        self.size = size\n",
        "        self.guesses = []\n",
        "        self.truths = []\n",
        "        self.errors = []\n",
        "        self.sles = []\n",
        "        self.colors = []\n",
        "\n",
        "    def color_for(self, error, truth):\n",
        "        if error<40 or error/truth < 0.2:\n",
        "            return \"green\"\n",
        "        elif error<80 or error/truth < 0.4:\n",
        "            return \"orange\"\n",
        "        else:\n",
        "            return \"red\"\n",
        "\n",
        "    def run_datapoint(self, i):\n",
        "        datapoint = self.data[i]\n",
        "\n",
        "        base_prompt = datapoint[\"text\"]\n",
        "        prompt = make_prompt(base_prompt)\n",
        "\n",
        "        guess = self.predictor(prompt)\n",
        "\n",
        "        # guess = self.predictor(datapoint[\"text\"])\n",
        "        truth = datapoint[\"price\"]\n",
        "        error = abs(guess - truth)\n",
        "        log_error = math.log(truth+1) - math.log(guess+1)\n",
        "        sle = log_error ** 2\n",
        "        color = self.color_for(error, truth)\n",
        "        title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
        "        self.guesses.append(guess)\n",
        "        self.truths.append(truth)\n",
        "        self.errors.append(error)\n",
        "        self.sles.append(sle)\n",
        "        self.colors.append(color)\n",
        "        print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
        "\n",
        "    def chart(self, title):\n",
        "        max_error = max(self.errors)\n",
        "        plt.figure(figsize=(12, 8))\n",
        "        max_val = max(max(self.truths), max(self.guesses))\n",
        "        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
        "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
        "        plt.xlabel('Ground Truth')\n",
        "        plt.ylabel('Model Estimate')\n",
        "        plt.xlim(0, max_val)\n",
        "        plt.ylim(0, max_val)\n",
        "        plt.title(title)\n",
        "        plt.show()\n",
        "\n",
        "    def report(self):\n",
        "        average_error = sum(self.errors) / self.size\n",
        "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
        "        hits = sum(1 for color in self.colors if color==\"green\")\n",
        "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
        "        self.chart(title)\n",
        "\n",
        "    def run(self):\n",
        "        self.error = 0\n",
        "        for i in range(self.size):\n",
        "            self.run_datapoint(i)\n",
        "        self.report()\n",
        "\n",
        "    @classmethod\n",
        "    def test(cls, function, data):\n",
        "        cls(function, data).run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dbWS1DPV4TPQ"
      },
      "outputs": [],
      "source": [
        "class Search_K:\n",
        "    \"\"\"\n",
        "    Search for the optimal 'k' value.\n",
        "    \"\"\"\n",
        "    def __init__(self, predictor, data, title=None, size=250):\n",
        "        self.predictor = predictor\n",
        "        self.data = data\n",
        "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
        "        self.size = size\n",
        "        self.truths = []\n",
        "\n",
        "        self.all_k_errors = []\n",
        "        self.max_k = K_SEARCH_LIMIT\n",
        "\n",
        "        # Store the list of probabilities for each inference\n",
        "        self.all_prob_lists = []\n",
        "        # Store the standard deviation of probs for each inference\n",
        "        self.prob_std_devs = []\n",
        "\n",
        "    def color_for(self, error, truth):\n",
        "        if error<40 or error/truth < 0.2:\n",
        "            return \"green\"\n",
        "        elif error<80 or error/truth < 0.4:\n",
        "            return \"orange\"\n",
        "        else:\n",
        "            return \"red\"\n",
        "\n",
        "    def run_datapoint(self, i):\n",
        "        datapoint = self.data[i]\n",
        "        base_prompt = datapoint[\"text\"]\n",
        "        prompt = make_prompt(base_prompt)\n",
        "        truth = datapoint[\"price\"]\n",
        "        self.truths.append(truth)\n",
        "\n",
        "        # Get the raw lists of prices and probabilities\n",
        "        prices, probabilities = self.predictor(prompt)\n",
        "\n",
        "        self.all_prob_lists.append(probabilities)\n",
        "\n",
        "        if probabilities:\n",
        "            # Calculate and store the spread (std dev) of this prob list\n",
        "            self.prob_std_devs.append(np.std(probabilities))\n",
        "        else:\n",
        "            # No probabilities, append 0 for spread\n",
        "            self.prob_std_devs.append(0.0)\n",
        "\n",
        "        errors_for_this_datapoint = []\n",
        "\n",
        "        if not prices:\n",
        "            print(f\"{i+1}: No valid prices found. Truth: ${truth:,.2f}.\")\n",
        "            error = np.abs(0 - truth)\n",
        "            errors_for_this_datapoint = [error] * self.max_k\n",
        "            self.all_k_errors.append(errors_for_this_datapoint)\n",
        "            return\n",
        "\n",
        "        # Iterate from k=1 up to max_k\n",
        "        for k in range(1, self.max_k + 1):\n",
        "            k_prices = prices[:k]\n",
        "            k_probabilities = probabilities[:k]\n",
        "\n",
        "            # Calculate the weighted price just for this k\n",
        "            guess = calculate_weighted_price(k_prices, k_probabilities)\n",
        "\n",
        "            # Calculate and store the error for this k\n",
        "            error = np.abs(guess - truth)\n",
        "            errors_for_this_datapoint.append(error)\n",
        "\n",
        "        # Store the list of errors (for k=1 to max_k)\n",
        "        self.all_k_errors.append(errors_for_this_datapoint)\n",
        "\n",
        "        # Print a summary for this datapoint\n",
        "        title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
        "\n",
        "        # Using [0], [19], [-1] for k=1, k=20, k=max_k (0-indexed)\n",
        "        k_1_err = errors_for_this_datapoint[0]\n",
        "        k_20_err = errors_for_this_datapoint[19]\n",
        "        k_max_err = errors_for_this_datapoint[-1]\n",
        "\n",
        "        color = self.color_for(k_1_err, truth)\n",
        "        print(f\"{COLOR_MAP[color]}{i+1}: Truth: ${truth:,.2f}. \"\n",
        "              f\"Errors (k=1, k=20, k={self.max_k}): \"\n",
        "              f\"(${k_1_err:,.2f}, ${k_20_err:,.2f}, ${k_max_err:,.2f}) \"\n",
        "              f\"Item: {title}{RESET}\")\n",
        "\n",
        "    def plot_k_vs_error(self, k_values, avg_errors_by_k, best_k, min_error):\n",
        "        \"\"\"\n",
        "        Plots the Average Error vs. k\n",
        "        \"\"\"\n",
        "        plt.figure(figsize=(12, 8))\n",
        "        plt.plot(k_values, avg_errors_by_k, label='Average Error vs. k')\n",
        "\n",
        "        # Highlight the best k\n",
        "        plt.axvline(x=best_k, color='red', linestyle='--',\n",
        "                    label=f'Best k = {best_k} (Avg Error: ${min_error:,.2f})')\n",
        "\n",
        "        plt.xlabel('Number of Top Probabilities/Prices (k)')\n",
        "        plt.ylabel('Average Absolute Error ($)')\n",
        "        plt.title(f'Optimal k Analysis for {self.title}')\n",
        "        plt.legend()\n",
        "        plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
        "        # Set x-axis to start at 1\n",
        "        plt.xlim(left=1)\n",
        "        plt.savefig(\"k_vs_error_plot.png\")\n",
        "        plt.show()\n",
        "\n",
        "\n",
        "    def plot_probability_spread(self, idx_min_std, idx_med_std, idx_max_std):\n",
        "        probs_min = self.all_prob_lists[idx_min_std]\n",
        "        probs_med = self.all_prob_lists[idx_med_std]\n",
        "        probs_max = self.all_prob_lists[idx_max_std]\n",
        "        std_min = self.prob_std_devs[idx_min_std]\n",
        "        std_med = self.prob_std_devs[idx_med_std]\n",
        "        std_max = self.prob_std_devs[idx_max_std]\n",
        "\n",
        "        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 7), sharey=True)\n",
        "        fig.suptitle('Probability Distribution Spread Analysis (Examples)', fontsize=16)\n",
        "\n",
        "        def plot_strip(ax, probs, title):\n",
        "            if not probs:\n",
        "                ax.set_title(f\"{title}\\n(No probabilities found)\")\n",
        "                return\n",
        "            jitter = np.random.normal(0, 0.01, size=len(probs))\n",
        "            ax.scatter(jitter, probs, alpha=0.5, s=10) # Made points slightly larger\n",
        "            ax.set_title(title)\n",
        "            ax.set_xlabel(\"Jitter\")\n",
        "            ax.get_xaxis().set_ticks([])\n",
        "\n",
        "        plot_strip(ax1, probs_min,\n",
        "                   f'Inference {idx_min_std} (Lowest Spread)\\nStd Dev: {std_min:.6f}')\n",
        "        ax1.set_ylabel('Probability')\n",
        "        plot_strip(ax2, probs_med,\n",
        "                   f'Inference {idx_med_std} (Median Spread)\\nStd Dev: {std_med:.6f}')\n",
        "        plot_strip(ax3, probs_max,\n",
        "                   f'Inference {idx_max_std} (Highest Spread)\\nStd Dev: {std_max:.6f}')\n",
        "\n",
        "        plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
        "        plt.savefig(\"spread_examples_plot.png\")\n",
        "        plt.show()\n",
        "\n",
        "    def plot_all_std_devs(self):\n",
        "        \"\"\"\n",
        "        Plots a histogram and a line plot of the standard deviation\n",
        "        for ALL inferences.\n",
        "        \"\"\"\n",
        "        if not self.prob_std_devs:\n",
        "            print(\"No probability spreads recorded, skipping all-std plot.\")\n",
        "            return\n",
        "\n",
        "        # Create a figure with two subplots\n",
        "        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))\n",
        "        fig.suptitle('Full Spread Analysis for All Inferences', fontsize=16)\n",
        "\n",
        "        # --- Plot Histogram ---\n",
        "        ax1.hist(self.prob_std_devs, bins=50, edgecolor='black')\n",
        "        ax1.set_title('Distribution of Probability Standard Deviations')\n",
        "        ax1.set_xlabel('Standard Deviation')\n",
        "        ax1.set_ylabel('Frequency (Number of Inferences)')\n",
        "\n",
        "        mean_std = np.mean(self.prob_std_devs)\n",
        "        ax1.axvline(mean_std, color='red', linestyle='--',\n",
        "                    label=f'Mean Std Dev: {mean_std:.6f}')\n",
        "        ax1.legend()\n",
        "\n",
        "        # --- Plot Line Plot ---\n",
        "        ax2.plot(self.prob_std_devs, marker='o', linestyle='-',\n",
        "                 markersize=3, alpha=0.7, label='Std Dev per Inference')\n",
        "        ax2.set_title('Probability Standard Deviation per Inference')\n",
        "        ax2.set_xlabel('Inference Index (0 to 249)')\n",
        "        ax2.set_ylabel('Standard Deviation')\n",
        "\n",
        "        ax2.axhline(mean_std, color='red', linestyle='--',\n",
        "                    label=f'Mean Std Dev: {mean_std:.6f}')\n",
        "        ax2.legend()\n",
        "        ax2.set_xlim(0, len(self.prob_std_devs) - 1)\n",
        "\n",
        "        plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
        "        plt.savefig(\"all_std_devs_plot.png\") # Save the plot\n",
        "        plt.show()\n",
        "\n",
        "    def report(self):\n",
        "        \"\"\"\n",
        "        Calls all three plotting functions.\n",
        "        \"\"\"\n",
        "        if not self.all_k_errors:\n",
        "             print(\"\\nNo data to report on. Exiting.\")\n",
        "             return\n",
        "\n",
        "        # Optimal k Analysis ---\n",
        "        errors_array = np.array(self.all_k_errors)\n",
        "        avg_errors_by_k = np.mean(errors_array, axis=0)\n",
        "        best_k_index = np.argmin(avg_errors_by_k)\n",
        "        min_error = avg_errors_by_k[best_k_index]\n",
        "        best_k = best_k_index + 1\n",
        "\n",
        "        print(\"\\n\" + \"=\"*40)\n",
        "        print(\"--- Optimal k Analysis Report ---\")\n",
        "        print(f\"Model: {self.title}\")\n",
        "        print(f\"Inferences Run: {self.size}\")\n",
        "        print(f\"Analyzed k from 1 to {self.max_k}\")\n",
        "        print(f\"===================================\")\n",
        "        print(f\"==> Best k: {best_k}\")\n",
        "        print(f\"==> Minimum Average Error: ${min_error:,.2f}\")\n",
        "        print(\"=\"*40 + \"\\n\")\n",
        "\n",
        "        k_values = np.arange(1, self.max_k + 1)\n",
        "        self.plot_k_vs_error(k_values, avg_errors_by_k, best_k, min_error)\n",
        "\n",
        "        # Probability Spread Analysis ---\n",
        "        if not self.prob_std_devs:\n",
        "            print(\"\\nNo probability spreads recorded, skipping spread plots.\")\n",
        "            return\n",
        "\n",
        "        print(\"\\n\" + \"=\"*40)\n",
        "        print(\"--- Probability Spread Analysis ---\")\n",
        "\n",
        "        # Find indices for examples\n",
        "        std_sorted_indices = np.argsort(self.prob_std_devs)\n",
        "        idx_min_std = std_sorted_indices[0]\n",
        "        idx_med_std = std_sorted_indices[len(std_sorted_indices) // 2]\n",
        "        idx_max_std = std_sorted_indices[-1]\n",
        "\n",
        "        print(f\"Lowest spread (std):  {self.prob_std_devs[idx_min_std]:.6f} (Inference {idx_min_std})\")\n",
        "        print(f\"Median spread (std): {self.prob_std_devs[idx_med_std]:.6f} (Inference {idx_med_std})\")\n",
        "        print(f\"Highest spread (std): {self.prob_std_devs[idx_max_std]:.6f} (Inference {idx_max_std})\")\n",
        "        print(\"=\"*40 + \"\\n\")\n",
        "\n",
        "        # Plot example spreads\n",
        "        self.plot_probability_spread(idx_min_std, idx_med_std, idx_max_std)\n",
        "\n",
        "        # Plot all spreads\n",
        "        self.plot_all_std_devs()\n",
        "\n",
        "        return best_k\n",
        "\n",
        "    def run(self):\n",
        "        for i in range(self.size):\n",
        "            self.run_datapoint(i)\n",
        "        best_k=self.report()\n",
        "        return best_k\n",
        "\n",
        "    @classmethod\n",
        "    def test(cls, function, data):\n",
        "        cls(function, data).run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vtt13OuVE-t7"
      },
      "outputs": [],
      "source": [
        "# Search best K\n",
        "search_k = Search_K(get_top_k_predictions, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
        "best_k = search_k.run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tuwYu1NYljIv"
      },
      "outputs": [],
      "source": [
        "top_K = best_k\n",
        "\n",
        "def improved_model_predict(prompt, device=\"cuda\"):\n",
        "    set_seed(42)\n",
        "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
        "    attention_mask = torch.ones(inputs.shape, device=device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
        "        next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
        "\n",
        "    next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
        "    top_prob, top_token_id = next_token_probs.topk(top_K)\n",
        "\n",
        "    prices = []\n",
        "    # Renamed 'weights' to 'probabilities' for clarity\n",
        "    probabilities = []\n",
        "\n",
        "    for i in range(top_K):\n",
        "      predicted_token = tokenizer.decode(top_token_id[0][i])\n",
        "      # This is a torch.Tensor\n",
        "      probability_tensor = top_prob[0][i]\n",
        "\n",
        "      # print(predicted_token, probability_tensor)\n",
        "\n",
        "      try:\n",
        "        # Try to convert the decoded token string to a float\n",
        "        price = float(predicted_token)\n",
        "      except ValueError as e:\n",
        "        price = 0.0\n",
        "\n",
        "      # Only include valid, positive prices\n",
        "      if price > 0:\n",
        "        prices.append(price)\n",
        "        # We append the tensor to our list\n",
        "        probabilities.append(probability_tensor)\n",
        "\n",
        "    if not prices:\n",
        "      # If no valid prices were found, return 0.0\n",
        "      return 0.0\n",
        "\n",
        "\n",
        "    # Convert the list of prices to a numpy array\n",
        "    prices_np = np.array(prices)\n",
        "\n",
        "    # Convert the list of torch.Tensors to a numpy array of floats\n",
        "    probs_np = np.array([p.item() for p in probabilities])\n",
        "\n",
        "    # Calculate the normalized weighted average\n",
        "    final_price = np.average(prices_np, weights=probs_np)\n",
        "\n",
        "    return float(final_price) # Return as a standard python float"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3SxpLBJH70E-"
      },
      "outputs": [],
      "source": [
        "prompt=make_prompt(test[80]['text'])\n",
        "print(prompt)\n",
        "\n",
        "improved_model_predict(prompt)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W_KcLvyt6kbb"
      },
      "outputs": [],
      "source": [
        "# Run Estimate vs Ground Truth\n",
        "tester = Tester(improved_model_predict, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
        "tester.run()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
